{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.01658978,  0.01415308, -0.03952613, -0.00342648], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('CartPole-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        done = terminated or truncated\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            done = True\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAF7CAYAAAD4/3BBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAApBElEQVR4nO3df3RU9Z3/8ddMfgw/wkwMkExSEkShYIRgCxhmbS1dUgJGK2s8Ry0r6HLgyCaeQizFdK2K7TEu9qw/ugrne3ZX3HOktPYrulLBxiCh1ogYSfmlqbC0wcIkFE4yIZohmfl8//DL7I4iyYQw80nyfJxzz8ncz3vuvO/nYPLy/hqHMcYIAADAIs5ENwAAAPB5BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYJ2EBpRnnnlGl19+uYYNG6bCwkK9++67iWwHAABYImEB5Ze//KUqKir00EMP6f3339f06dNVXFyslpaWRLUEAAAs4UjUlwUWFhZq1qxZ+td//VdJUjgcVm5uru69917df//9iWgJAABYIjkRH3r27FnV19ersrIyss7pdKqoqEh1dXVfqA8GgwoGg5HX4XBYp0+f1ujRo+VwOOLSMwAAuDjGGLW3tysnJ0dO54VP4iQkoPz1r39VKBRSVlZW1PqsrCx9+OGHX6ivqqrS2rVr49UeAAC4hI4dO6Zx48ZdsCYhASVWlZWVqqioiLxua2tTXl6ejh07JrfbncDOAABAbwUCAeXm5mrUqFE91iYkoIwZM0ZJSUlqbm6OWt/c3Cyv1/uFepfLJZfL9YX1brebgAIAwADTm8szEnIXT2pqqmbMmKGamprIunA4rJqaGvl8vkS0BAAALJKwUzwVFRVasmSJZs6cqWuvvVZPPvmkOjo6dPfddyeqJQAAYImEBZTbbrtNJ0+e1IMPPii/369rrrlG27dv/8KFswAAYOhJ2HNQLkYgEJDH41FbWxvXoAAAMEDE8veb7+IBAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALBOvweUhx9+WA6HI2qZMmVKZLyzs1NlZWUaPXq00tLSVFpaqubm5v5uAwAADGCX5AjK1VdfrRMnTkSWt956KzK2atUqvfrqq3rxxRdVW1ur48eP65ZbbrkUbQAAgAEq+ZJsNDlZXq/3C+vb2tr07//+79q0aZP+9m//VpL03HPP6aqrrtI777yj2bNnX4p2AADAAHNJjqB89NFHysnJ0RVXXKFFixapqalJklRfX6+uri4VFRVFaqdMmaK8vDzV1dV96faCwaACgUDUAgAABq9+DyiFhYXauHGjtm/frvXr1+vo0aP65je/qfb2dvn9fqWmpio9PT3qPVlZWfL7/V+6zaqqKnk8nsiSm5vb320DAACL9PspngULFkR+LigoUGFhocaPH69f/epXGj58eJ+2WVlZqYqKisjrQCBASAEAYBC75LcZp6en66tf/aoOHz4sr9ers2fPqrW1Naqmubn5vNesnONyueR2u6MWAAAweF3ygHLmzBkdOXJE2dnZmjFjhlJSUlRTUxMZb2xsVFNTk3w+36VuBQAADBD9fornBz/4gW666SaNHz9ex48f10MPPaSkpCTdcccd8ng8Wrp0qSoqKpSRkSG32617771XPp+PO3gAAEBEvweUjz/+WHfccYdOnTqlsWPH6hvf+IbeeecdjR07VpL0xBNPyOl0qrS0VMFgUMXFxXr22Wf7uw0AADCAOYwxJtFNxCoQCMjj
8aitrY3rUQAAGCBi+fvNd/EAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKwTc0DZtWuXbrrpJuXk5MjhcOjll1+OGjfG6MEHH1R2draGDx+uoqIiffTRR1E1p0+f1qJFi+R2u5Wenq6lS5fqzJkzF7UjAABg8Ig5oHR0dGj69Ol65plnzju+bt06Pf3009qwYYN2796tkSNHqri4WJ2dnZGaRYsW6eDBg6qurtbWrVu1a9cuLV++vO97AQAABhWHMcb0+c0Oh7Zs2aKFCxdK+uzoSU5Oju677z794Ac/kCS1tbUpKytLGzdu1O23364PPvhA+fn52rNnj2bOnClJ2r59u2644QZ9/PHHysnJ6fFzA4GAPB6P2tra5Ha7+9o+AACIo1j+fvfrNShHjx6V3+9XUVFRZJ3H41FhYaHq6uokSXV1dUpPT4+EE0kqKiqS0+nU7t27z7vdYDCoQCAQtQAAgMGrXwOK3++XJGVlZUWtz8rKioz5/X5lZmZGjScnJysjIyNS83lVVVXyeDyRJTc3tz/bBgAAlhkQd/FUVlaqra0tshw7dizRLQEAgEuoXwOK1+uVJDU3N0etb25ujox5vV61tLREjXd3d+v06dORms9zuVxyu91RCwAAGLz6NaBMmDBBXq9XNTU1kXWBQEC7d++Wz+eTJPl8PrW2tqq+vj5Ss2PHDoXDYRUWFvZnOwAAYIBKjvUNZ86c0eHDhyOvjx49qoaGBmVkZCgvL08rV67UT3/6U02aNEkTJkzQj3/8Y+Xk5ETu9Lnqqqs0f/58LVu2TBs2bFBXV5fKy8t1++239+oOHgAAMPjFHFDee+89ffvb3468rqiokCQtWbJEGzdu1A9/+EN1dHRo+fLlam1t1Te+8Q1t375dw4YNi7znhRdeUHl5uebOnSun06nS0lI9/fTT/bA7AABgMLio56AkCs9BAQBg4EnYc1AAAAD6AwEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1Yg4ou3bt0k033aScnBw5HA69/PLLUeN33XWXHA5H1DJ//vyomtOnT2vRokVyu91KT0/X0qVLdebMmYvaEQAAMHjEHFA6Ojo0ffp0PfPMM19aM3/+fJ04cSKy/OIXv4gaX7RokQ4ePKjq6mpt3bpVu3bt0vLly2PvHgAADErJsb5hwYIFWrBgwQVrXC6XvF7vecc++OADbd++XXv27NHMmTMlST//+c91ww036Gc/+5lycnJibQkAAAwyl+QalJ07dyozM1OTJ0/WihUrdOrUqchYXV2d0tPTI+FEkoqKiuR0OrV79+7zbi8YDCoQCEQtAABg8Or3gDJ//nz953/+p2pqavTP//zPqq2t1YIFCxQKhSRJfr9fmZmZUe9JTk5WRkaG/H7/ebdZVVUlj8cTWXJzc/u7bQAAYJGYT/H05Pbbb4/8PG3aNBUUFOjKK6/Uzp07NXfu3D5ts7KyUhUVFZHXgUCAkAIAwCB2yW8zvuKKKzRmzBgdPnxYkuT1etXS0hJV093drdOnT3/pdSsul0tutztqAQAAg9clDygff/yxTp06pezsbEmSz+dTa2ur6uvrIzU7duxQOBxWYWHhpW4HAAAM
ADGf4jlz5kzkaIgkHT16VA0NDcrIyFBGRobWrl2r0tJSeb1eHTlyRD/84Q81ceJEFRcXS5KuuuoqzZ8/X8uWLdOGDRvU1dWl8vJy3X777dzBAwAAJEkOY4yJ5Q07d+7Ut7/97S+sX7JkidavX6+FCxdq7969am1tVU5OjubNm6ef/OQnysrKitSePn1a5eXlevXVV+V0OlVaWqqnn35aaWlpveohEAjI4/Gora2N0z0AAAwQsfz9jjmg2ICAAgDAwBPL32++iwcAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArBPzlwUCQF8YY3Tkjf+jcPfZC9Zd/q3FSh3hiVNXAGxFQAEQN4FjBxXq6rxgTSjYITPcLYfDEaeuANiIUzwArBLuuvARFgBDAwEFgFVC3cFEtwDAAgQUAFbhCAoAiYACwDI9XUQLYGggoACwSphTPABEQAFgGY6gAJAIKAAsQ0ABIBFQAFiGgAJAIqAAsEyoKyjJJLoNAAlGQAEQN5nT5vZY03LwTckQUIChjoACIG6SXSN7rAmHuuPQCQDbEVAAxI0zOTXRLQAYIAgoAOKGgAKgtwgoAOLGmUJAAdA7BBQAcZOU7Ep0CwAGCAIKgLhxElAA9BIBBUDccIoHQG8RUADEDad4APRWTAGlqqpKs2bN0qhRo5SZmamFCxeqsbExqqazs1NlZWUaPXq00tLSVFpaqubm5qiapqYmlZSUaMSIEcrMzNTq1avV3c2zD4DBjiMoAHorpoBSW1ursrIyvfPOO6qurlZXV5fmzZunjo6OSM2qVav06quv6sUXX1Rtba2OHz+uW265JTIeCoVUUlKis2fP6u2339bzzz+vjRs36sEHH+y/vQJgHYfDIWcSAQVA7ziM6fszpU+ePKnMzEzV1tbq+uuvV1tbm8aOHatNmzbp1ltvlSR9+OGHuuqqq1RXV6fZs2dr27ZtuvHGG3X8+HFlZWVJkjZs2KA1a9bo5MmTSk3t+RdYIBCQx+NRW1ub3G53X9sHEGfdwQ7t3bjqwkUOp75+1xNKSh0en6YAxE0sf78v6hqUtrY2SVJGRoYkqb6+Xl1dXSoqKorUTJkyRXl5eaqrq5Mk1dXVadq0aZFwIknFxcUKBAI6ePDgeT8nGAwqEAhELQAGrxDfaAwMeX0OKOFwWCtXrtR1112nqVOnSpL8fr9SU1OVnp4eVZuVlSW/3x+p+d/h5Nz4ubHzqaqqksfjiSy5ubl9bRvAABDuCia6BQAJ1ueAUlZWpgMHDmjz5s392c95VVZWqq2tLbIcO3bskn8mgEQxCncTUIChLrkvbyovL9fWrVu1a9cujRs3LrLe6/Xq7Nmzam1tjTqK0tzcLK/XG6l59913o7Z37i6fczWf53K55HJxeyIwVIS6OMUDDHUxHUExxqi8vFxbtmzRjh07NGHChKjxGTNmKCUlRTU1NZF1jY2Nampqks/nkyT5fD7t379fLS0tkZrq6mq53W7l5+dfzL4AGCTCXIMCDHkxHUEpKyvTpk2b9Morr2jUqFGRa0Y8Ho+GDx8uj8ejpUuXqqKiQhkZGXK73br33nvl8/k0e/ZsSdK8efOUn5+vO++8U+vWrZPf79cDDzygsrIyjpIAkERAARBjQFm/fr0kac6cOVHrn3vuOd11112SpCeeeEJOp1OlpaUKBoMqLi7Ws88+G6lNSkrS1q1btWLFCvl8Po0cOVJLlizRI488cnF7AmBwMAQUABf5HJRE4TkowMDUq+egyKHLv7VYY6dcF5eeAMRP3J6DAgCxcSjJNbLHqmD7qTj0AsBmBBQAceNMSlHGlTN7qDI6fXh3XPoBYC8CCoD4cUjOZL6PB0DPCCgA4sghZ1JKopsAMAAQUADEFUdQAPQGAQVA3DgcDgIKgF4hoACIK2cyp3gA9IyAAiCOOIICoHcIKADixyE5CCgAeoGAAiCOHEpKIqAA6BkBBUBc
cQ0KgN4goACIG4fDIYczqedCY2TC4UvfEABrEVAAWMfIKBzqSnQbABKIgALAPoaAAgx1BBQA9jFG4e6zie4CQAIRUABYx8jIdHMEBRjKCCgA7MMpHmDII6AAsI8xCoe6E90FgAQioACwjpGR4QgKMKQRUADYxxiFuQYFGNIIKACsY7gGBRjyCCgA4irZNVLD0rMvWBPuPquOlqNx6giAjQgoAOIqyTVcw9KzLlhjQl3qbPXHqSMANiKgAIgrh8MppzM50W0AsBwBBUB8OZxyJBFQAFwYAQVAXDkIKAB6gYACIL4cDjkJKAB6QEABEFcOh1MOrkEB0AMCCoD4cjrlSE5JdBcALEdAARBX3MUDoDcIKADiysE1KAB6IaaAUlVVpVmzZmnUqFHKzMzUwoUL1djYGFUzZ84cORyOqOWee+6JqmlqalJJSYlGjBihzMxMrV69Wt3dfHMpMCRwFw+AXojpt0Rtba3Kyso0a9YsdXd360c/+pHmzZunQ4cOaeTIkZG6ZcuW6ZFHHom8HjFiROTnUCikkpISeb1evf322zpx4oQWL16slJQUPfroo/2wSwBs5nA4JEfP/29kTFgmHJLDmRSHrgDYJqaAsn379qjXGzduVGZmpurr63X99ddH1o8YMUJer/e82/jtb3+rQ4cO6Y033lBWVpauueYa/eQnP9GaNWv08MMPKzU1tQ+7AWCwMWECCjCUXdQ1KG1tbZKkjIyMqPUvvPCCxowZo6lTp6qyslKffPJJZKyurk7Tpk1TVtb/fBdHcXGxAoGADh48eN7PCQaDCgQCUQuAwc2YkMIhTv0CQ1WfTwSHw2GtXLlS1113naZOnRpZ/73vfU/jx49XTk6O9u3bpzVr1qixsVEvvfSSJMnv90eFE0mR137/+b8crKqqSmvXru1rqwAGoM+OoBBQgKGqzwGlrKxMBw4c0FtvvRW1fvny5ZGfp02bpuzsbM2dO1dHjhzRlVde2afPqqysVEVFReR1IBBQbm5u3xoHMDCEQzKhUKK7AJAgfTrFU15erq1bt+rNN9/UuHHjLlhbWFgoSTp8+LAkyev1qrm5Oarm3Osvu27F5XLJ7XZHLQAGt3MXyQIYmmIKKMYYlZeXa8uWLdqxY4cmTJjQ43saGhokSdnZ2ZIkn8+n/fv3q6WlJVJTXV0tt9ut/Pz8WNoBMIiZcEhhTvEAQ1ZMp3jKysq0adMmvfLKKxo1alTkmhGPx6Phw4fryJEj2rRpk2644QaNHj1a+/bt06pVq3T99deroKBAkjRv3jzl5+frzjvv1Lp16+T3+/XAAw+orKxMLper//cQwIBkTJhTPMAQFtMRlPXr16utrU1z5sxRdnZ2ZPnlL38pSUpNTdUbb7yhefPmacqUKbrvvvtUWlqqV199NbKNpKQkbd26VUlJSfL5fPr7v/97LV68OOq5KQAgLpIFhrSYjqAYYy44npubq9ra2h63M378eL322muxfDSAIcaEQ1yDAgxhfBcPgLgbOXa8hl2Wc8GaT1v9OtN8JE4dAbANAQVA3CUlu+RMTrlwkQnzoDZgCCOgAIg7hzNJDgePsAfw5QgoAOLO4UySw8mvHwBfjt8QAOLus4DCERQAX46AAiDuHEkEFAAXRkABEHcOZ7IcDn79APhy/IYAEHec4gHQEwIKgLgjoADoCQEFQNxxFw+AnvAbAkDcxXIEpaev2AAwOBFQAMSdw+HoVZ0JdUsmfIm7AWAjAgoAa5lwN0dQgCGKgALAWuFQtwxHUIAhiYACwFom1C2FQ4luA0ACEFAAWOuzUzwcQQGGIgIKAGuFQ90yYQIKMBQRUABYi7t4gKGLgALAWiYc4hQPMEQlJ7oBAAOPMUah0MVdvBoO93z7cKi7S91dXXJ2d/f5c5KSknr93BUA9iCgAIjZxx9/rCuuuOKitrH6Np++e91XlXSBR97verNaVcse1rGWQJ8+IykpSe3t7UpJSelrmwAShIACoE+6L+KohiTt+sOfdH1B
nkZ7RnxpzTUTvRrrGa6jx0/36TPCXGALDFgEFAAJEewKKWyMQiZJzcHL9UnYLckoLalVWal/EmdlgKGNgAIgIc52hxQ20vuB7yjQPUZdZpgko1Rnp1rOjlfBqNpEtwgggQgoABKiOyTtbi1RyqhcSf9zuCQYHqnjwUlyyOjqtN8lrkEACcVtxgASYsqMFUpOm6T/HU7OMXLq4+BkHf20IP6NAbACAQVAYjgcPdz+69D5wguAoYGAAgAArENAAQAA1iGgAEiIN7b/TK2nDks63xNljbyp/63Lh++Pd1sALBFTQFm/fr0KCgrkdrvldrvl8/m0bdu2yHhnZ6fKyso0evRopaWlqbS0VM3NzVHbaGpqUklJiUaMGKHMzEytXr36oh/4BGDg6fjkjK4d9X/lST6pZEdQUlhSWCmOTmWm/lnXjHpDSY6Le5w+gIErptuMx40bp8cee0yTJk2SMUbPP/+8br75Zu3du1dXX321Vq1apd/85jd68cUX5fF4VF5erltuuUW///3vJUmhUEglJSXyer16++23deLECS1evFgpKSl69NFHL8kOArBTKGy04/3/VobnZ/pL5ySdCV0mh4xGJZ/SuGF/1LH/X3eytSOhfQJIDIcxpudv7LqAjIwMPf7447r11ls1duxYbdq0Sbfeeqsk6cMPP9RVV12luro6zZ49W9u2bdONN96o48ePKysrS5K0YcMGrVmzRidPnlRqamqvPjMQCMjj8eiuu+7q9XsA9J+Ojg698MILiW6jRw6HQ0uXLpXzAt/3AyB+zp49q40bN6qtrU1ut/uCtX1+UFsoFNKLL76ojo4O+Xw+1dfXq6urS0VFRZGaKVOmKC8vLxJQ6urqNG3atEg4kaTi4mKtWLFCBw8e1Ne+9rXzflYwGFQwGIy8DgQ+++KwO++8U2lpaX3dBQB91NzcPGACyt13363kZJ5JCdjgzJkz2rhxY69qY/6vdv/+/fL5fOrs7FRaWpq2bNmi/Px8NTQ0KDU1Venp6VH1WVlZ8vv9kiS/3x8VTs6Nnxv7MlVVVVq7du0X1s+cObPHBAag/x07dqznIkvMmjWLbzMGLHHuAENvxHzcc/LkyWpoaNDu3bu1YsUKLVmyRIcOHYp1MzGprKxUW1tbZBlIvxwBAEDsYj6CkpqaqokTJ0qSZsyYoT179uipp57SbbfdprNnz6q1tTXqKEpzc7O8Xq8kyev16t13343a3rm7fM7VnI/L5ZLL5Yq1VQAAMEBd9JVj4XBYwWBQM2bMUEpKimpqaiJjjY2Nampqks/nkyT5fD7t379fLS0tkZrq6mq53W7l5+dfbCsAAGCQiOkISmVlpRYsWKC8vDy1t7dr06ZN2rlzp15//XV5PB4tXbpUFRUVysjIkNvt1r333iufz6fZs2dLkubNm6f8/HzdeeedWrdunfx+vx544AGVlZVxhAQAAETEFFBaWlq0ePFinThxQh6PRwUFBXr99df1ne98R5L0xBNPyOl0qrS0VMFgUMXFxXr22Wcj709KStLWrVu1YsUK+Xw+jRw5UkuWLNEjjzzSv3sFAAAGtIt+DkoinHsOSm/uowbQ/44dO6a8vLxEt9Ejp9Opzs5O7uIBLBHL32+eXgQAAKxDQAEAANYhoAAAAOsQUAAAgHX4ggoAMRs+fLgWLlyY6DZ65HQ65XA4Et0GgD4goACI2ZgxY7Rly5ZEtwFgEOMUDwAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYJ2YAsr69etVUFAgt9stt9stn8+nbdu2RcbnzJkjh8MRtdxzzz1R22hqalJJSYlGjBihzMxMrV69Wt3d3f2zNwAAYFBIjqV43Lhxeuyx
xzRp0iQZY/T888/r5ptv1t69e3X11VdLkpYtW6ZHHnkk8p4RI0ZEfg6FQiopKZHX69Xbb7+tEydOaPHixUpJSdGjjz7aT7sEAAAGOocxxlzMBjIyMvT4449r6dKlmjNnjq655ho9+eST563dtm2bbrzxRh0/flxZWVmSpA0bNmjNmjU6efKkUlNTe/WZgUBAHo9HbW1tcrvdF9M+AACIk1j+fvf5GpRQKKTNmzero6NDPp8vsv6FF17QmDFjNHXqVFVWVuqTTz6JjNXV1WnatGmRcCJJxcXFCgQCOnjw4Jd+VjAYVCAQiFoAAMDgFdMpHknav3+/fD6fOjs7lZaWpi1btig/P1+S9L3vfU/jx49XTk6O9u3bpzVr1qixsVEvvfSSJMnv90eFE0mR136//0s/s6qqSmvXro21VQAAMEDFHFAmT56shoYGtbW16de//rWWLFmi2tpa5efna/ny5ZG6adOmKTs7W3PnztWRI0d05ZVX9rnJyspKVVRURF4HAgHl5ub2eXsAAMBuMZ/iSU1N1cSJEzVjxgxVVVVp+vTpeuqpp85bW1hYKEk6fPiwJMnr9aq5uTmq5txrr9f7pZ/pcrkidw6dWwAAwOB10c9BCYfDCgaD5x1raGiQJGVnZ0uSfD6f9u/fr5aWlkhNdXW13G535DQRAABATKd4KisrtWDBAuXl5am9vV2bNm3Szp079frrr+vIkSPatGmTbrjhBo0ePVr79u3TqlWrdP3116ugoECSNG/ePOXn5+vOO+/UunXr5Pf79cADD6isrEwul+uS7CAAABh4YgooLS0tWrx4sU6cOCGPx6OCggK9/vrr+s53vqNjx47pjTfe0JNPPqmOjg7l5uaqtLRUDzzwQOT9SUlJ2rp1q1asWCGfz6eRI0dqyZIlUc9NAQAAuOjnoCQCz0EBAGDgictzUAAAAC4VAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYJ3kRDfQF8YYSVIgEEhwJwAAoLfO/d0+93f8QgZkQGlvb5ck5ebmJrgTAAAQq/b2dnk8ngvWOExvYoxlwuGwGhsblZ+fr2PHjsntdie6pQErEAgoNzeXeewHzGX/YS77B/PYf5jL/mGMUXt7u3JycuR0XvgqkwF5BMXpdOorX/mKJMntdvOPpR8wj/2Huew/zGX/YB77D3N58Xo6cnIOF8kCAADrEFAAAIB1BmxAcblceuihh+RyuRLdyoDGPPYf5rL/MJf9g3nsP8xl/A3Ii2QBAMDgNmCPoAAAgMGLgAIAAKxDQAEAANYhoAAAAOsMyIDyzDPP6PLLL9ewYcNUWFiod999N9EtWWfXrl266aablJOTI4fDoZdffjlq3BijBx98UNnZ2Ro+fLiKior00UcfRdWcPn1aixYtktvtVnp6upYuXaozZ87EcS8Sr6qqSrNmzdKoUaOUmZmphQsXqrGxMaqms7NTZWVlGj16tNLS0lRaWqrm5uaomqamJpWUlGjEiBHKzMzU6tWr1d3dHc9dSaj169eroKAg8pArn8+nbdu2RcaZw7577LHH5HA4tHLlysg65rN3Hn74YTkcjqhlypQpkXHmMcHMALN582aTmppq/uM//sMcPHjQLFu2zKSnp5vm5uZEt2aV1157zfzTP/2Teemll4wks2XLlqjxxx57zHg8HvPyyy+bP/zhD+a73/2umTBhgvn0008jNfPnzzfTp08377zzjvnd735nJk6caO64444470liFRcXm+eee84cOHDANDQ0mBtuuMHk5eWZM2fORGruuece
k5uba2pqasx7771nZs+ebf7mb/4mMt7d3W2mTp1qioqKzN69e81rr71mxowZYyorKxOxSwnxX//1X+Y3v/mN+eMf/2gaGxvNj370I5OSkmIOHDhgjGEO++rdd981l19+uSkoKDDf//73I+uZz9556KGHzNVXX21OnDgRWU6ePBkZZx4Ta8AFlGuvvdaUlZVFXodCIZOTk2OqqqoS2JXdPh9QwuGw8Xq95vHHH4+sa21tNS6Xy/ziF78wxhhz6NAhI8ns2bMnUrNt2zbjcDjMX/7yl7j1bpuWlhYjydTW1hpjPpu3lJQU8+KLL0ZqPvjgAyPJ1NXVGWM+C4tOp9P4/f5Izfr1643b7TbBYDC+O2CRyy67zPzbv/0bc9hH7e3tZtKkSaa6utp861vfigQU5rP3HnroITN9+vTzjjGPiTegTvGcPXtW9fX1KioqiqxzOp0qKipSXV1dAjsbWI4ePSq/3x81jx6PR4WFhZF5rKurU3p6umbOnBmpKSoqktPp1O7du+Pesy3a2tokSRkZGZKk+vp6dXV1Rc3llClTlJeXFzWX06ZNU1ZWVqSmuLhYgUBABw8ejGP3dgiFQtq8ebM6Ojrk8/mYwz4qKytTSUlJ1LxJ/JuM1UcffaScnBxdccUVWrRokZqamiQxjzYYUF8W+Ne//lWhUCjqH4MkZWVl6cMPP0xQVwOP3++XpPPO47kxv9+vzMzMqPHk5GRlZGREaoaacDislStX6rrrrtPUqVMlfTZPqampSk9Pj6r9/Fyeb67PjQ0V+/fvl8/nU2dnp9LS0rRlyxbl5+eroaGBOYzR5s2b9f7772vPnj1fGOPfZO8VFhZq48aNmjx5sk6cOKG1a9fqm9/8pg4cOMA8WmBABRQgkcrKynTgwAG99dZbiW5lQJo8ebIaGhrU1tamX//611qyZIlqa2sT3daAc+zYMX3/+99XdXW1hg0bluh2BrQFCxZEfi4oKFBhYaHGjx+vX/3qVxo+fHgCO4M0wO7iGTNmjJKSkr5wFXVzc7O8Xm+Cuhp4zs3VhebR6/WqpaUlary7u1unT58eknNdXl6urVu36s0339S4ceMi671er86ePavW1tao+s/P5fnm+tzYUJGamqqJEydqxowZqqqq0vTp0/XUU08xhzGqr69XS0uLvv71rys5OVnJycmqra3V008/reTkZGVlZTGffZSenq6vfvWrOnz4MP8uLTCgAkpqaqpmzJihmpqayLpwOKyamhr5fL4EdjawTJgwQV6vN2oeA4GAdu/eHZlHn8+n1tZW1dfXR2p27NihcDiswsLCuPecKMYYlZeXa8uWLdqxY4cmTJgQNT5jxgylpKREzWVjY6Oampqi5nL//v1Rga+6ulput1v5+fnx2RELhcNhBYNB5jBGc+fO1f79+9XQ0BBZZs6cqUWLFkV+Zj775syZMzpy5Iiys7P5d2mDRF+lG6vNmzcbl8tlNm7caA4dOmSWL19u0tPTo66ixmdX+O/du9fs3bvXSDL/8i//Yvbu3Wv+/Oc/G2M+u804PT3dvPLKK2bfvn3m5ptvPu9txl/72tfM7t27zVtvvWUmTZo05G4zXrFihfF4PGbnzp1RtyJ+8sknkZp77rnH5OXlmR07dpj33nvP+Hw+4/P5IuPnbkWcN2+eaWhoMNu3bzdjx44dUrci3n///aa2ttYcPXrU7Nu3z9x///3G4XCY3/72t8YY5vBi/e+7eIxhPnvrvvvuMzt37jRHjx41v//9701RUZEZM2aMaWlpMcYwj4k24AKKMcb8/Oc/N3l5eSY1NdVce+215p133kl0S9Z58803jaQvLEuWLDHGfHar8Y9//GOTlZVlXC6XmTt3rmlsbIzaxqlTp8wdd9xh0tLSjNvtNnfffbdpb29PwN4kzvnmUJJ57rnnIjWffvqp+cd//Edz2WWXmREjRpi/+7u/MydOnIjazp/+9CezYMECM3z4cDNmzBhz3333ma6urjjvTeL8wz/8gxk/
frxJTU01Y8eONXPnzo2EE2OYw4v1+YDCfPbObbfdZrKzs01qaqr5yle+Ym677TZz+PDhyDjzmFgOY4xJzLEbAACA8xtQ16AAAIChgYACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOv8PwKm5jkQ/ATvAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "#打印游戏\n",
    "def show():\n",
    "    plt.imshow(env.render())\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "这个游戏的状态用4个数字表示,我也不知道这4个数字分别是什么意思,反正这4个数字就能描述游戏全部的状态\n",
      "state= [-0.01169164  0.03815211 -0.04544982 -0.0105797 ]\n",
      "这个游戏一共有2个动作,不是0就是1\n",
      "env.action_space= Discrete(2)\n",
      "随机一个动作\n",
      "action= 1\n",
      "执行一个动作,得到下一个状态,奖励,是否结束\n",
      "state= [-0.01092859  0.2338954  -0.04566142 -0.3172491 ]\n",
      "reward= 1.0\n",
      "over= False\n"
     ]
    }
   ],
   "source": [
    "#测试游戏环境\n",
    "def test_env():\n",
    "    state = env.reset()\n",
    "    print('这个游戏的状态用4个数字表示,我也不知道这4个数字分别是什么意思,反正这4个数字就能描述游戏全部的状态')\n",
    "    print('state=', state)\n",
    "    #state= [ 0.03490619  0.04873464  0.04908862 -0.00375859]\n",
    "\n",
    "    print('这个游戏一共有2个动作,不是0就是1')\n",
    "    print('env.action_space=', env.action_space)\n",
    "    #env.action_space= Discrete(2)\n",
    "\n",
    "    print('随机一个动作')\n",
    "    action = env.action_space.sample()\n",
    "    print('action=', action)\n",
    "    #action= 1\n",
    "\n",
    "    print('执行一个动作,得到下一个状态,奖励,是否结束')\n",
    "    state, reward, over, _ = env.step(action)\n",
    "\n",
    "    print('state=', state)\n",
    "    #state= [ 0.02018229 -0.16441101  0.01547085  0.2661691 ]\n",
    "\n",
    "    print('reward=', reward)\n",
    "    #reward= 1.0\n",
    "\n",
    "    print('over=', over)\n",
    "    #over= False\n",
    "\n",
    "\n",
    "test_env()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Sequential(\n",
       "   (0): Linear(in_features=4, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=2, bias=True)\n",
       " ),\n",
       " Sequential(\n",
       "   (0): Linear(in_features=4, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=2, bias=True)\n",
       " ))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "#计算动作的模型,也是真正要用的模型\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 2),\n",
    ")\n",
    "\n",
    "#经验网络,用于评估一个状态的分数\n",
    "next_model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 2),\n",
    ")\n",
    "\n",
    "#把model的参数复制给next_model\n",
    "next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "model, next_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "#得到一个动作\n",
    "def get_action(state):\n",
    "    if random.random() < 0.01:\n",
    "        return random.choice([0, 1])\n",
    "\n",
    "    #走神经网络,得到一个动作\n",
    "    state = torch.FloatTensor(state).reshape(1, 4)\n",
    "\n",
    "    return model(state).argmax().item()\n",
    "\n",
    "\n",
    "get_action([0.0013847, -0.01194451, 0.04260966, 0.00688801])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((205, 0), 205)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#样本池\n",
    "datas = []\n",
    "\n",
    "\n",
    "#向样本池中添加N条数据,删除M条最古老的数据\n",
    "def update_data():\n",
    "    old_count = len(datas)\n",
    "\n",
    "    #玩到新增了N个数据为止\n",
    "    while len(datas) - old_count < 200:\n",
    "        #初始化游戏\n",
    "        state = env.reset()\n",
    "\n",
    "        #玩到游戏结束为止\n",
    "        over = False\n",
    "        while not over:\n",
    "            #根据当前状态得到一个动作\n",
    "            action = get_action(state)\n",
    "\n",
    "            #执行动作,得到反馈\n",
    "            next_state, reward, over, _ = env.step(action)\n",
    "\n",
    "            #记录数据样本\n",
    "            datas.append((state, action, reward, next_state, over))\n",
    "\n",
    "            #更新游戏状态,开始下一个动作\n",
    "            state = next_state\n",
    "\n",
    "    update_count = len(datas) - old_count\n",
    "    drop_count = max(len(datas) - 10000, 0)\n",
    "\n",
    "    #数据上限,超出时从最古老的开始删除\n",
    "    while len(datas) > 10000:\n",
    "        datas.pop(0)\n",
    "\n",
    "    return update_count, drop_count\n",
    "\n",
    "\n",
    "update_data(), len(datas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1185/2021397730.py:7: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ../torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 4)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[ 3.1741e-02,  2.0289e-02,  7.9580e-03, -2.4187e-02],\n",
       "         [ 7.5212e-02,  1.3804e+00, -1.6951e-01, -2.1854e+00],\n",
       "         [ 6.7671e-03,  2.8668e-02, -3.0978e-02, -2.3809e-02],\n",
       "         [ 4.8752e-02,  1.0071e+00, -9.4027e-02, -1.5546e+00],\n",
       "         [-1.6167e-02,  1.7738e-01, -3.2740e-02, -3.3079e-01],\n",
       "         [ 1.4403e-02,  2.1756e-01, -2.6746e-02, -3.1935e-01],\n",
       "         [ 2.3100e-02,  4.0988e-01,  2.8089e-02, -6.0916e-01],\n",
       "         [ 4.2824e-02,  7.6068e-01, -3.9855e-02, -1.2182e+00],\n",
       "         [-4.1832e-02,  2.3190e-01,  2.8036e-02, -2.3532e-01],\n",
       "         [-4.3200e-02,  1.6935e-01, -3.2603e-02, -3.2165e-01],\n",
       "         [-5.1604e-03,  5.6860e-01, -5.2028e-02, -9.3843e-01],\n",
       "         [ 6.8047e-02,  1.3233e+00, -1.6086e-01, -2.2005e+00],\n",
       "         [ 7.7163e-02,  1.1521e+00, -9.4681e-02, -1.8351e+00],\n",
       "         [-1.8315e-02, -4.0477e-02,  1.9169e-02, -3.1452e-02],\n",
       "         [ 5.3453e-02,  5.4993e-01,  8.3369e-03, -8.7470e-01],\n",
       "         [ 1.6690e-02,  9.4573e-01, -7.7105e-02, -1.4961e+00],\n",
       "         [ 7.9273e-02,  1.1901e+00, -5.5038e-02, -1.7744e+00],\n",
       "         [ 2.0220e-02,  6.1541e-01, -5.0547e-02, -9.3293e-01],\n",
       "         [-1.9124e-02,  1.5437e-01,  1.8540e-02, -3.1803e-01],\n",
       "         [ 9.7863e-03,  9.5410e-01, -2.3370e-02, -1.4712e+00],\n",
       "         [ 5.0342e-02,  8.0036e-01, -5.3790e-02, -1.2100e+00],\n",
       "         [-1.6037e-02,  3.4922e-01,  1.2179e-02, -6.0480e-01],\n",
       "         [ 1.3657e-01,  1.5975e+00, -1.2810e-01, -2.3496e+00],\n",
       "         [ 1.2087e-01,  1.3314e+00, -9.6917e-02, -2.0741e+00],\n",
       "         [ 3.8247e-02,  6.0477e-01, -3.5663e-02, -9.0634e-01],\n",
       "         [-3.3226e-02,  6.3093e-01, -4.5812e-02, -8.8013e-01],\n",
       "         [ 2.8371e-02,  1.2112e+00, -4.1114e-02, -1.6973e+00],\n",
       "         [ 6.9043e-02,  1.6111e+00, -1.9635e-01, -2.4798e+00],\n",
       "         [ 4.0747e-02,  1.4148e+00, -1.5347e-01, -2.1440e+00],\n",
       "         [ 1.4016e-03,  5.3927e-01, -4.8963e-02, -9.3143e-01],\n",
       "         [-4.6730e-02,  2.3982e-01, -2.8771e-02, -2.7523e-01],\n",
       "         [-8.3666e-03,  8.2095e-01,  9.1111e-03, -1.1107e+00],\n",
       "         [ 7.2513e-02,  1.5985e+00, -1.2637e-01, -2.3137e+00],\n",
       "         [-2.4064e-02,  3.6942e-01,  3.0204e-02, -6.0856e-01],\n",
       "         [-6.1769e-03,  9.5229e-01, -9.4859e-02, -1.5514e+00],\n",
       "         [ 7.8762e-02,  1.5411e+00, -1.2981e-01, -2.3983e+00],\n",
       "         [ 1.1654e-01,  1.3878e+00, -1.1867e-01, -2.1205e+00],\n",
       "         [-1.6675e-02,  5.6411e-01,  1.8033e-02, -8.9158e-01],\n",
       "         [ 1.0958e-01,  1.7371e+00, -1.7777e-01, -2.7279e+00],\n",
       "         [ 1.1278e-01,  1.7986e+00, -1.6146e-01, -2.6432e+00],\n",
       "         [ 4.1811e-02, -4.8363e-02, -3.5341e-04, -3.6730e-02],\n",
       "         [ 1.7805e-01,  1.7234e+00, -1.8630e-01, -2.7270e+00],\n",
       "         [ 4.0844e-02,  1.4676e-01, -1.0880e-03, -3.2952e-01],\n",
       "         [ 3.1298e-02,  6.0460e-01,  1.5906e-02, -8.9286e-01],\n",
       "         [ 1.8382e-02,  2.0671e-02,  3.5568e-02, -4.6345e-02],\n",
       "         [-1.2619e-02,  3.7295e-01, -3.9356e-02, -6.3362e-01],\n",
       "         [ 2.6887e-02,  9.3094e-01, -9.2373e-02, -1.5522e+00],\n",
       "         [-2.9522e-02,  4.3159e-01,  3.6541e-02, -5.4527e-01],\n",
       "         [-1.9209e-02, -3.1839e-02, -1.8596e-02,  1.3546e-02],\n",
       "         [-4.2676e-02, -2.6219e-02, -3.2224e-02, -1.8975e-02],\n",
       "         [-2.8662e-02,  6.2139e-01,  1.2949e-02, -8.0427e-01],\n",
       "         [ 6.6349e-02,  9.9613e-01, -7.7990e-02, -1.5191e+00],\n",
       "         [ 3.5837e-02,  1.3447e+00, -1.6333e-01, -2.2010e+00],\n",
       "         [ 2.6980e-02,  4.2569e-01,  2.5938e-02, -5.5649e-01],\n",
       "         [ 6.2731e-02,  1.5409e+00, -2.0735e-01, -2.5393e+00],\n",
       "         [ 1.2869e-02,  1.1484e+00, -1.2589e-01, -1.8721e+00],\n",
       "         [-4.1933e-02,  4.3534e-01, -3.4275e-02, -5.7685e-01],\n",
       "         [-3.9813e-02,  3.6492e-01, -3.9036e-02, -6.2443e-01],\n",
       "         [ 2.2360e-02,  2.3103e-01,  3.1415e-02, -2.7388e-01],\n",
       "         [ 7.5278e-02,  1.1963e+00, -1.1944e-01, -1.8609e+00],\n",
       "         [ 5.8439e-02,  1.3378e+00, -1.4326e-01, -2.1358e+00],\n",
       "         [-1.6234e-02,  8.1634e-01, -3.1366e-03, -1.0928e+00],\n",
       "         [ 1.6242e-01,  1.7777e+00, -1.8025e-01, -2.7331e+00],\n",
       "         [-1.6575e-02,  3.5892e-01, -2.4024e-02, -5.8335e-01]]),\n",
       " tensor([[1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1],\n",
       "         [1]]),\n",
       " tensor([[1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.],\n",
       "         [1.]]),\n",
       " tensor([[ 3.2146e-02,  2.1530e-01,  7.4743e-03, -3.1435e-01],\n",
       "         [ 1.0282e-01,  1.5767e+00, -2.1322e-01, -2.5252e+00],\n",
       "         [ 7.3404e-03,  2.2422e-01, -3.1455e-02, -3.2610e-01],\n",
       "         [ 6.8894e-02,  1.2032e+00, -1.2512e-01, -1.8751e+00],\n",
       "         [-1.2619e-02,  3.7295e-01, -3.9356e-02, -6.3362e-01],\n",
       "         [ 1.8754e-02,  4.1305e-01, -3.3133e-02, -6.2035e-01],\n",
       "         [ 3.1298e-02,  6.0460e-01,  1.5906e-02, -8.9286e-01],\n",
       "         [ 5.8038e-02,  9.5629e-01, -6.4219e-02, -1.5231e+00],\n",
       "         [-3.7194e-02,  4.2661e-01,  2.3329e-02, -5.1903e-01],\n",
       "         [-3.9813e-02,  3.6492e-01, -3.9036e-02, -6.2443e-01],\n",
       "         [ 6.2117e-03,  7.6439e-01, -7.0797e-02, -1.2470e+00],\n",
       "         [ 9.4512e-02,  1.5195e+00, -2.0487e-01, -2.5382e+00],\n",
       "         [ 1.0021e-01,  1.3482e+00, -1.3138e-01, -2.1556e+00],\n",
       "         [-1.9124e-02,  1.5437e-01,  1.8540e-02, -3.1803e-01],\n",
       "         [ 6.4452e-02,  7.4494e-01, -9.1572e-03, -1.1648e+00],\n",
       "         [ 3.5605e-02,  1.1417e+00, -1.0703e-01, -1.8118e+00],\n",
       "         [ 1.0307e-01,  1.3858e+00, -9.0525e-02, -2.0836e+00],\n",
       "         [ 3.2528e-02,  8.1117e-01, -6.9206e-02, -1.2411e+00],\n",
       "         [-1.6037e-02,  3.4922e-01,  1.2179e-02, -6.0480e-01],\n",
       "         [ 2.8868e-02,  1.1495e+00, -5.2793e-02, -1.7711e+00],\n",
       "         [ 6.6349e-02,  9.9613e-01, -7.7990e-02, -1.5191e+00],\n",
       "         [-9.0525e-03,  5.4417e-01,  8.2930e-05, -8.9363e-01],\n",
       "         [ 1.6852e-01,  1.7935e+00, -1.7509e-01, -2.6788e+00],\n",
       "         [ 1.4750e-01,  1.5274e+00, -1.3840e-01, -2.3951e+00],\n",
       "         [ 5.0342e-02,  8.0036e-01, -5.3790e-02, -1.2100e+00],\n",
       "         [-2.0608e-02,  8.2664e-01, -6.3415e-02, -1.1869e+00],\n",
       "         [ 5.2596e-02,  1.4068e+00, -7.5060e-02, -2.0025e+00],\n",
       "         [ 1.0126e-01,  1.8072e+00, -2.4595e-01, -2.8258e+00],\n",
       "         [ 6.9043e-02,  1.6111e+00, -1.9635e-01, -2.4798e+00],\n",
       "         [ 1.2187e-02,  7.3502e-01, -6.7591e-02, -1.2391e+00],\n",
       "         [-4.1933e-02,  4.3534e-01, -3.4275e-02, -5.7685e-01],\n",
       "         [ 8.0523e-03,  1.0159e+00, -1.3104e-02, -1.4005e+00],\n",
       "         [ 1.0448e-01,  1.7946e+00, -1.7265e-01, -2.6424e+00],\n",
       "         [-1.6675e-02,  5.6411e-01,  1.8033e-02, -8.9158e-01],\n",
       "         [ 1.2869e-02,  1.1484e+00, -1.2589e-01, -1.8721e+00],\n",
       "         [ 1.0958e-01,  1.7371e+00, -1.7777e-01, -2.7279e+00],\n",
       "         [ 1.4430e-01,  1.5839e+00, -1.6108e-01, -2.4473e+00],\n",
       "         [-5.3932e-03,  7.5898e-01,  2.0085e-04, -1.1785e+00],\n",
       "         [ 1.4432e-01,  1.9330e+00, -2.3233e-01, -3.0690e+00],\n",
       "         [ 1.4876e-01,  1.9945e+00, -2.1432e-01, -2.9805e+00],\n",
       "         [ 4.0844e-02,  1.4676e-01, -1.0880e-03, -3.2952e-01],\n",
       "         [ 2.1251e-01,  1.9193e+00, -2.4084e-01, -3.0702e+00],\n",
       "         [ 4.3779e-02,  3.4190e-01, -7.6785e-03, -6.2255e-01],\n",
       "         [ 4.3390e-02,  7.9950e-01, -1.9516e-03, -1.1805e+00],\n",
       "         [ 1.8795e-02,  2.1526e-01,  3.4641e-02, -3.2760e-01],\n",
       "         [-5.1604e-03,  5.6860e-01, -5.2028e-02, -9.3843e-01],\n",
       "         [ 4.5506e-02,  1.1270e+00, -1.2342e-01, -1.8722e+00],\n",
       "         [-2.0890e-02,  6.2618e-01,  2.5636e-02, -8.2622e-01],\n",
       "         [-1.9846e-02,  1.6355e-01, -1.8325e-02, -2.8495e-01],\n",
       "         [-4.3200e-02,  1.6935e-01, -3.2603e-02, -3.2165e-01],\n",
       "         [-1.6234e-02,  8.1634e-01, -3.1366e-03, -1.0928e+00],\n",
       "         [ 8.6272e-02,  1.1921e+00, -1.0837e-01, -1.8350e+00],\n",
       "         [ 6.2731e-02,  1.5409e+00, -2.0735e-01, -2.5393e+00],\n",
       "         [ 3.5494e-02,  6.2044e-01,  1.4808e-02, -8.4089e-01],\n",
       "         [ 9.3549e-02,  1.7371e+00, -2.5814e-01, -2.8877e+00],\n",
       "         [ 3.5837e-02,  1.3447e+00, -1.6333e-01, -2.2010e+00],\n",
       "         [-3.3226e-02,  6.3093e-01, -4.5812e-02, -8.8013e-01],\n",
       "         [-3.2515e-02,  5.6056e-01, -5.1525e-02, -9.2915e-01],\n",
       "         [ 2.6980e-02,  4.2569e-01,  2.5938e-02, -5.5649e-01],\n",
       "         [ 9.9204e-02,  1.3925e+00, -1.5666e-01, -2.1882e+00],\n",
       "         [ 8.5196e-02,  1.5341e+00, -1.8598e-01, -2.4691e+00],\n",
       "         [ 9.2303e-05,  1.0115e+00, -2.4994e-02, -1.3865e+00],\n",
       "         [ 1.9798e-01,  1.9736e+00, -2.3492e-01, -3.0748e+00],\n",
       "         [-9.3964e-03,  5.5437e-01, -3.5691e-02, -8.8350e-01]]),\n",
       " tensor([[0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [1],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1],\n",
       "         [0]]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Draw one training mini-batch from the replay pool\n",
    "def get_sample(batch_size=64):\n",
    "    \"\"\"Sample transitions from the global `datas` pool and pack them into tensors.\n",
    "\n",
    "    Args:\n",
    "        batch_size: number of transitions to draw (without replacement).\n",
    "            Defaults to 64, the value that was previously hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (state, action, reward, next_state, over):\n",
    "        state / next_state are float tensors reshaped to [b, 4],\n",
    "        action / over are long tensors reshaped to [b, 1],\n",
    "        reward is a float tensor reshaped to [b, 1].\n",
    "\n",
    "    Raises:\n",
    "        ValueError: propagated from random.sample when the pool holds\n",
    "            fewer than batch_size transitions.\n",
    "    \"\"\"\n",
    "    # each sampled item is indexed as (state, action, reward, next_state, over)\n",
    "    samples = random.sample(datas, batch_size)\n",
    "\n",
    "    # [b, 4]\n",
    "    state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 4)\n",
    "    # [b, 1]\n",
    "    action = torch.LongTensor([i[1] for i in samples]).reshape(-1, 1)\n",
    "    # [b, 1]\n",
    "    reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)\n",
    "    # [b, 4]\n",
    "    next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 4)\n",
    "    # [b, 1]\n",
    "    over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "state, action, reward, next_state, over"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 1])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_value(state, action):\n",
    "    \"\"\"Return the model's score for the action actually taken in each state.\n",
    "\n",
    "    Only the current state is used: before an action is executed neither\n",
    "    the reward nor next_state is known, so they play no part here.\n",
    "\n",
    "    Args:\n",
    "        state: batch of states, [b, 4].\n",
    "        action: index of the action taken for each state, [b, 1].\n",
    "\n",
    "    Returns:\n",
    "        Tensor [b, 1] holding the score of the chosen action per row.\n",
    "    \"\"\"\n",
    "    # model(state): [b, 4] -> [b, 2], one score per action\n",
    "    # gather along dim 1: [b, 2] -> [b, 1], keep the chosen action's score\n",
    "    return model(state).gather(dim=1, index=action)\n",
    "\n",
    "\n",
    "get_value(state, action).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 1])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_target(reward, next_state, over, gamma=0.98):\n",
    "    \"\"\"Build the training target for the scores produced by get_value.\n",
    "\n",
    "    There is no exact answer for how much a state is worth, so the target\n",
    "    is bootstrapped from the delayed-update network `next_model`:\n",
    "    reward + gamma * max_a next_model(next_state)[a] * (1 - over).\n",
    "\n",
    "    Args:\n",
    "        reward: reward received for each transition, [b, 1].\n",
    "        next_state: state reached after the action, [b, 4].\n",
    "        over: 1 where the game ended at next_state, else 0, [b, 1].\n",
    "        gamma: weight applied to the next-state score. Defaults to 0.98,\n",
    "            the value that was previously hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        Tensor [b, 1] of target scores.\n",
    "    \"\"\"\n",
    "    # score every action in the next state, [b, 4] -> [b, 2]\n",
    "    # no_grad: the target network is evaluated, never trained through\n",
    "    with torch.no_grad():\n",
    "        target = next_model(next_state)\n",
    "\n",
    "    # keep the best action's score, [b, 2] -> [b, 1]\n",
    "    target = target.max(dim=1)[0]\n",
    "    target = target.reshape(-1, 1)\n",
    "\n",
    "    # weight the next-state score\n",
    "    target *= gamma\n",
    "\n",
    "    # a finished game has no future, so its next-state score is zeroed\n",
    "    # [b, 1] * [b, 1] -> [b, 1]\n",
    "    target *= (1 - over)\n",
    "\n",
    "    # adding the immediate reward gives the final target\n",
    "    # [b, 1] + [b, 1] -> [b, 1]\n",
    "    target += reward\n",
    "\n",
    "    return target\n",
    "\n",
    "\n",
    "get_target(reward, next_state, over).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8.0"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    \"\"\"Play one episode with get_action and return the summed reward.\n",
    "\n",
    "    Args:\n",
    "        play: when truthy, the frame is redrawn via show() on roughly\n",
    "            one step in five (random frame skipping).\n",
    "\n",
    "    Returns:\n",
    "        Sum of the rewards collected until the environment reports the\n",
    "        episode is over; larger is better.\n",
    "    \"\"\"\n",
    "    # start a fresh game\n",
    "    state = env.reset()\n",
    "    total_reward = 0\n",
    "\n",
    "    while True:\n",
    "        # pick an action for the current state and apply it\n",
    "        state, reward, over, _ = env.step(get_action(state))\n",
    "        total_reward += reward\n",
    "\n",
    "        # animation: random.random() is only consulted when play is set\n",
    "        if play:\n",
    "            if random.random() < 0.2:\n",
    "                display.clear_output(wait=True)\n",
    "                show()\n",
    "\n",
    "        # stop once the game has ended\n",
    "        if over:\n",
    "            return total_reward\n",
    "\n",
    "\n",
    "test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "OHoSU6uI-xIt",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 406 201 0 11.2\n",
      "50 10000 200 200 198.6\n",
      "100 10000 344 344 199.7\n",
      "150 10000 200 200 199.45\n",
      "200 10000 340 340 200.0\n",
      "250 10000 200 200 200.0\n",
      "300 10000 200 200 190.4\n",
      "350 10000 200 200 191.7\n",
      "400 10000 200 200 200.0\n",
      "450 10000 341 341 169.55\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    \"\"\"Train `model` on sampled transitions and report progress periodically.\n",
    "\n",
    "    Runs 500 epochs. Each epoch calls update_data() and then performs 200\n",
    "    Adam steps (lr=2e-3) on the MSE between get_value and get_target.\n",
    "    The weights of `model` are copied into `next_model` every 10 steps.\n",
    "    Every 50 epochs it prints: epoch, len(datas), update_count,\n",
    "    drop_count and the mean reward of 20 test episodes.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    criterion = torch.nn.MSELoss()\n",
    "    adam = torch.optim.Adam(model.parameters(), lr=2e-3)\n",
    "\n",
    "    for epoch in range(500):\n",
    "        # refresh the data pool before learning from it\n",
    "        update_count, drop_count = update_data()\n",
    "\n",
    "        for step in range(1, 201):\n",
    "            # draw a batch, then compare predicted and target scores\n",
    "            state, action, reward, next_state, over = get_sample()\n",
    "            loss = criterion(get_value(state, action),\n",
    "                             get_target(reward, next_state, over))\n",
    "\n",
    "            # gradient step\n",
    "            adam.zero_grad()\n",
    "            loss.backward()\n",
    "            adam.step()\n",
    "\n",
    "            # sync the delayed network on every 10th step\n",
    "            if step % 10 == 0:\n",
    "                next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "        # only report on every 50th epoch\n",
    "        if epoch % 50 != 0:\n",
    "            continue\n",
    "        scores = [test(play=False) for _ in range(20)]\n",
    "        print(epoch, len(datas), update_count, drop_count, sum(scores) / 20)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAF7CAYAAAD4/3BBAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAn0UlEQVR4nO3dfXSU5Z3/8c/kkYcwEwMkk5QEURCIEOwChllbS0tKgOjKGveoZSF2OXBkE08hlmK6VMTuMS7uWR+6CH9sV9xzpLT0iK5UsDFIqDU8mJLypKnwYxssTILyy0wSJSSZ6/eHP+6zo4hMCMw1w/t1zn3OzH1dc9/f+zoc5pPrfhiXMcYIAADAIgnRLgAAAODzCCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDpRDShr167V9ddfrwEDBqiwsFB79+6NZjkAAMASUQsov/zlL1VZWalVq1bpD3/4gyZNmqTi4mK1trZGqyQAAGAJV7R+LLCwsFBTp07Vv//7v0uSQqGQcnNz9dBDD+mRRx6JRkkAAMASSdHY6blz59TQ0KCqqipnXUJCgoqKilRfX/+F/l1dXerq6nLeh0IhnTlzRkOHDpXL5boqNQMAgMtjjFF7e7tycnKUkHDxkzhRCSgfffSRent7lZWVFbY+KytL77///hf6V1dXa/Xq1VerPAAAcAWdOHFCI0aMuGifqASUSFVVVamystJ5HwgElJeXpxMnTsjtdkexMgAAcKmCwaByc3M1ZMiQr+wblYAybNgwJSYmqqWlJWx9S0uLvF7vF/qnpqYqNTX1C+vdbjcBBQCAGHMpl2dE5S6elJQUTZ48WbW1tc66UCik2tpa+Xy+aJQEAAAsErVTPJWVlSorK9OUKVN066236plnnlFnZ6e+//3vR6skAABgiagFlHvvvVenT5/Wo48+Kr/fr1tuuUXbt2//woWzAADg2hO156BcjmAwKI/Ho0AgwDUoAADEiEi+v/ktHgAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6/R7QHnsscfkcrnClnHjxjntZ8+eVXl5uYYOHaq0tDSVlpaqpaWlv8sAAAAx7IrMoNx88806deqUs7z99ttO27Jly/Taa69p8+bNqqur08mTJ3X33XdfiTIAAECMSroiG01Kktfr/cL6QCCgn//859q4caO+853vSJJeeOEFjR8/Xrt379a0adOuRDkAACDGXJEZlA8++EA5OTm64YYbNG/ePDU3N0uSGhoa1N3draKiIqfvuHHjlJeXp/r6+i/dXldXl4LBYNgCAADiV78HlMLCQm3YsEHbt2/XunXrdPz4cX3zm99Ue3u7/H6/UlJSlJ6eHvaZrKws+f3+L91mdXW1PB6Ps+Tm5vZ32QAAwCL9fopn9uzZzuuCggIVFhZq5MiR+tWvfqWBAwf2aZtVVVWqrKx03geDQUIKAABx7IrfZpyenq6bbrpJR48eldfr1blz59TW1hbWp6Wl5YLXrJyXmpoqt9sdtgAAgPh1xQNKR0eHjh07puzsbE2ePFnJycmqra112puamtTc3Cyfz3elSwEAADGi30/x/PCHP9Sdd96pkSNH6uTJk1q1apUSExN1//33y+PxaOHChaqsrFRGRobcbrceeugh+Xw+7uABAACOfg8oH374oe6//359/PHHGj58uL7xjW9o9+7dGj58uCTp6aefVkJCgkpLS9XV1aXi4mI9//zz/V0GAACIYS5jjIl2EZEKBoPyeDwK
BAJcjwIAQIyI5Pub3+IBAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFgn4oCya9cu3XnnncrJyZHL5dIrr7wS1m6M0aOPPqrs7GwNHDhQRUVF+uCDD8L6nDlzRvPmzZPb7VZ6eroWLlyojo6OyzoQAAAQPyIOKJ2dnZo0aZLWrl17wfY1a9boueee0/r167Vnzx4NHjxYxcXFOnv2rNNn3rx5Onz4sGpqarR161bt2rVLixcv7vtRAACAuOIyxpg+f9jl0pYtWzR37lxJn82e5OTk6OGHH9YPf/hDSVIgEFBWVpY2bNig++67T++9957y8/O1b98+TZkyRZK0fft2zZkzRx9++KFycnK+cr/BYFAej0eBQEBut7uv5QMAgKsoku/vfr0G5fjx4/L7/SoqKnLWeTweFRYWqr6+XpJUX1+v9PR0J5xIUlFRkRISErRnz54Lbrerq0vBYDBsAQAA8atfA4rf75ckZWVlha3Pyspy2vx+vzIzM8Pak5KSlJGR4fT5vOrqank8HmfJzc3tz7IBAIBlYuIunqqqKgUCAWc5ceJEtEsCAABXUL8GFK/XK0lqaWkJW9/S0uK0eb1etba2hrX39PTozJkzTp/PS01NldvtDlsAAED86teAMmrUKHm9XtXW1jrrgsGg9uzZI5/PJ0ny+Xxqa2tTQ0OD02fHjh0KhUIqLCzsz3IAAECMSor0Ax0dHTp69Kjz/vjx42psbFRGRoby8vK0dOlS/fM//7PGjBmjUaNG6Sc/+YlycnKcO33Gjx+vWbNmadGiRVq/fr26u7tVUVGh++6775Lu4AEAAPEv4oDy7rvv6tvf/rbzvrKyUpJUVlamDRs26Ec/+pE6Ozu1ePFitbW16Rvf+Ia2b9+uAQMGOJ956aWXVFFRoRkzZighIUGlpaV67rnn+uFwAABAPLis56BEC89BAQAg9kTtOSgAAAD9gYACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6EQeUXbt26c4771ROTo5cLpdeeeWVsPYHHnhALpcrbJk1a1ZYnzNnzmjevHlyu91KT0/XwoUL1dHRcVkHAgAA4kfEAaWzs1OTJk3S2rVrv7TPrFmzdOrUKWf5xS9+EdY+b948HT58WDU1Ndq6dat27dqlxYsXR149AACIS0mRfmD27NmaPXv2RfukpqbK6/VesO29997T9u3btW/fPk2ZMkWS9LOf/Uxz5szRv/7rvyonJyfSkgAAQJy5Iteg7Ny5U5mZmRo7dqyWLFmijz/+2Gmrr69Xenq6E04kqaioSAkJCdqzZ88Ft9fV1aVgMBi2AACA+NXvAWXWrFn6r//6L9XW1upf/uVfVFdXp9mzZ6u3t1eS5Pf7lZmZGfaZpKQkZWRkyO/3X3Cb1dXV8ng8zpKbm9vfZQMAAItEfIrnq9x3333O64kTJ6qgoEA33nijdu7cqRkzZvRpm1VVVaqsrHTeB4NBQgoAAHHsit9mfMMNN2jYsGE6evSoJMnr9aq1tTWsT09Pj86cOfOl162kpqbK7XaHLQAAIH5d8YDy4Ycf6uOPP1Z2drYkyefzqa2tTQ0NDU6fHTt2KBQKqbCw8EqXAwAAYkDEp3g6
Ojqc2RBJOn78uBobG5WRkaGMjAytXr1apaWl8nq9OnbsmH70ox9p9OjRKi4uliSNHz9es2bN0qJFi7R+/Xp1d3eroqJC9913H3fwAAAASZLLGGMi+cDOnTv17W9/+wvry8rKtG7dOs2dO1f79+9XW1ubcnJyNHPmTP30pz9VVlaW0/fMmTOqqKjQa6+9poSEBJWWluq5555TWlraJdUQDAbl8XgUCAQ43QMAQIyI5Ps74oBiAwIKAACxJ5Lvb36LBwAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsE/GPBQLAlXS87kV1fxK8aJ/cafdo4HXZV6kiANFAQAFglfa/NKmr/aOL9sn++mwZY+Ryua5SVQCuNk7xAIg9sfcbpwAiREABEHOMCUW7BABXGAEFQAxiBgWIdwQUADGHMzxA/COgAIg9JBQg7hFQAMQcrkEB4h8BBUAMYgYFiHcEFACxh1M8QNwjoACIPQQUIO4RUADEnM+uQSGkAPGMgAIgBhFOgHhHQAEQcwyneIC4R0ABEHvIJ0DcI6AAiDk8BwWIfwQUALHHGGZRgDhHQAEQe5hBAeIeAQVAzDFMnwBxj4ACIPZwFw8Q9yIKKNXV1Zo6daqGDBmizMxMzZ07V01NTWF9zp49q/Lycg0dOlRpaWkqLS1VS0tLWJ/m5maVlJRo0KBByszM1PLly9XT03P5RwPg2kBAAeJeRAGlrq5O5eXl2r17t2pqatTd3a2ZM2eqs7PT6bNs2TK99tpr2rx5s+rq6nTy5EndfffdTntvb69KSkp07tw5vfPOO3rxxRe1YcMGPfroo/13VADi2mfPQSGkAPHMZS7jiUenT59WZmam6urqdPvttysQCGj48OHauHGj7rnnHknS+++/r/Hjx6u+vl7Tpk3Ttm3bdMcdd+jkyZPKysqSJK1fv14rVqzQ6dOnlZKS8pX7DQaD8ng8CgQCcrvdfS0fgIUObPyxuto/umif67+1QMNu8smVkHiVqgLQHyL5/r6sa1ACgYAkKSMjQ5LU0NCg7u5uFRUVOX3GjRunvLw81dfXS5Lq6+s1ceJEJ5xIUnFxsYLBoA4fPnzB/XR1dSkYDIYtAK5dPEkWiH99DiihUEhLly7VbbfdpgkTJkiS/H6/UlJSlJ6eHtY3KytLfr/f6fO/w8n59vNtF1JdXS2Px+Msubm5fS0bQDwgoABxr88Bpby8XIcOHdKmTZv6s54LqqqqUiAQcJYTJ05c8X0CsBdPkgXiX1JfPlRRUaGtW7dq165dGjFihLPe6/Xq3LlzamtrC5tFaWlpkdfrdfrs3bs3bHvn7/I53+fzUlNTlZqa2pdSAcQlZlCAeBfRDIoxRhUVFdqyZYt27NihUaNGhbVPnjxZycnJqq2tddY1NTWpublZPp9PkuTz+XTw4EG1trY6fWpqauR2u5Wfn385xwLgWsEpHiDuRTSDUl5ero0bN+rVV1/VkCFDnGtGPB6PBg4cKI/Ho4ULF6qyslIZGRlyu9166KGH5PP5NG3aNEnSzJkzlZ+fr/nz52vNmjXy+/1auXKlysvLmSUBcEm4SBaIfxEFlHXr1kmSpk+fHrb+hRde0AMPPCBJevrpp5WQkKDS0lJ1dXWpuLhYzz//vNM3MTFRW7du1ZIlS+Tz+TR48GCVlZXp8ccfv7wjAXDtIKAAce+ynoMSLTwHBYhfl/IclFzf3ylrwnd4DgoQY67ac1AAIBqM4ecCgXhHQAEQe4zhNA8Q5wgoAGIOz0EB4h8BBUDsYfYEiHsEFACxh4ACxD0CCoCYE4M3HwKIEAEFQAwioADxjoACIOYwgwLEPwIKgNjDXTxA3COgAIg5zKAA8Y+AAiAGEVCAeEdAARB7mEEB4h4BBUDsMUbMogDxjYACIOZwDQoQ/wgoAGIQAQWIdwQUADGHGRQg/hFQAMQeE2ISBYhzBBQAMYcZFCD+EVAAxCACChDv
CCgAYg4zKED8I6AAiD0EFCDuEVAAWCUte8xX9uls/T8yod6rUA2AaCGgALDK4GEjJbku2ufTM3+RMQQUIJ4RUADYxXXxcALg2kBAAWAXAgoAEVAAWMb1Fad3AFwbCCgA7OLivyUABBQAlnFxigeACCgAbENAASACCgDbEFAAiIACwDKc4gEgRRhQqqurNXXqVA0ZMkSZmZmaO3eumpqawvpMnz5dLpcrbHnwwQfD+jQ3N6ukpESDBg1SZmamli9frp6enss/GgAxz8VFsgAkJUXSua6uTuXl5Zo6dap6enr04x//WDNnztSRI0c0ePBgp9+iRYv0+OOPO+8HDRrkvO7t7VVJSYm8Xq/eeecdnTp1SgsWLFBycrKeeOKJfjgkADGNGRQAijCgbN++Pez9hg0blJmZqYaGBt1+++3O+kGDBsnr9V5wG7/97W915MgRvfnmm8rKytItt9yin/70p1qxYoUee+wxpaSk9OEwAMQPAgqAy7wGJRAISJIyMjLC1r/00ksaNmyYJkyYoKqqKn3yySdOW319vSZOnKisrCxnXXFxsYLBoA4fPnzB/XR1dSkYDIYtAOIT16AAkCKcQfnfQqGQli5dqttuu00TJkxw1n/ve9/TyJEjlZOTowMHDmjFihVqamrSyy+/LEny+/1h4USS897v919wX9XV1Vq9enVfSwUQS7gGBYAuI6CUl5fr0KFDevvtt8PWL1682Hk9ceJEZWdna8aMGTp27JhuvPHGPu2rqqpKlZWVzvtgMKjc3Ny+FQ7AasygAJD6eIqnoqJCW7du1VtvvaURI0ZctG9hYaEk6ejRo5Ikr9erlpaWsD7n33/ZdSupqalyu91hC4B4RUABEGFAMcaooqJCW7Zs0Y4dOzRq1Kiv/ExjY6MkKTs7W5Lk8/l08OBBtba2On1qamrkdruVn58fSTkA4lECp3gARHiKp7y8XBs3btSrr76qIUOGONeMeDweDRw4UMeOHdPGjRs1Z84cDR06VAcOHNCyZct0++23q6CgQJI0c+ZM5efna/78+VqzZo38fr9Wrlyp8vJypaam9v8RAogpLpfrs0kUE+1KAERTRH+qrFu3ToFAQNOnT1d2draz/PKXv5QkpaSk6M0339TMmTM1btw4PfzwwyotLdVrr73mbCMxMVFbt25VYmKifD6f/v7v/14LFiwIe24KgGsZp3gARDiDYszF/6TJzc1VXV3dV25n5MiRev311yPZNYBrBE+SBSDxWzwAbMNdPABEQAFgGW4zBiARUADYhoACQAQUALbhGhQAIqAAsAyneABIBBQAliGgAJAIKACsQ0ABQEABYBuuQQEgAgoA23CKB4AIKAAswzUoACQCCgDrEFAAEFAAWIbf4gEgEVAA2IZTPABEQAFgGa5BASARUADYhlM8AERAAWAZZlAASAQUANYhoAAgoACwDDMoACQCCgDbcA0KAElJ0S4AQHwJhUIKhUJ9/nxvqPfS+vX0Sok9fd6Py+VSYmJinz8P4MriTxUA/eqRRx7RwIED+7zccMON6u396pCSnZ19Wft54IEHrvxgAOgzZlAA9KtQKKSenr7PbJzr7r6kfj09PZe1n0sJQQCih4ACwCohYyTz2euOnnR93J2jrtBApSSc1XXJfnmSPo5ugQCuCgIKAKsY81k6OdOdpSMdt+mTkFu9JlmJ6tHAxKBuGvSuslL/HOUqAVxpXIMCwCqhkFFnr0d/CBarvXeYek2KJJd6layO3qE60DFd/7c7M9plArjCCCgArBIKJeh3bX+nbjPggu09JlW7A3ep26Re5coAXE0EFABWCRmjr36aLA9zA+IdAQWAVc5fgwLg2kZAAWCVXgIKABFQANjGhOTzbFGCLvyMkwT1aor7N0pynbvKhQG4miIKKOvWrVNBQYHcbrfcbrd8Pp+2bdvmtJ89e1bl5eUaOnSo0tLSVFpaqpaWlrBtNDc3q6SkRIMGDVJmZqaWL19+WQ9bAhBfjDHyJH2kKe7t
GpQQ+P9BxShBPRqY0K6CIW9pWPJf5BIzLUA8i+g5KCNGjNCTTz6pMWPGyBijF198UXfddZf279+vm2++WcuWLdNvfvMbbd68WR6PRxUVFbr77rv1+9//XtJnT24sKSmR1+vVO++8o1OnTmnBggVKTk7WE088cUUOEEBsCRmjV3/fpISEPynY80e1nhups6HBSnF9quEpJ9SW/NkfPed4EiwQ11zmMq9Iy8jI0FNPPaV77rlHw4cP18aNG3XPPfdIkt5//32NHz9e9fX1mjZtmrZt26Y77rhDJ0+eVFZWliRp/fr1WrFihU6fPq2UlJRL2mcwGJTH49EDDzxwyZ8BcHXs3r1bBw4ciHYZX2n06NH6zne+E+0ygGvKuXPntGHDBgUCAbnd7ov27fOTZHt7e7V582Z1dnbK5/OpoaFB3d3dKioqcvqMGzdOeXl5TkCpr6/XxIkTnXAiScXFxVqyZIkOHz6sr3/96xfcV1dXl7q6upz3wWBQkjR//nylpaX19RAAXAGdnZ0xEVBuuOEGLVy4MNplANeUjo4Obdiw4ZL6RhxQDh48KJ/Pp7NnzyotLU1btmxRfn6+GhsblZKSovT09LD+WVlZ8vv9kiS/3x8WTs63n2/7MtXV1Vq9evUX1k+ZMuUrExiAq8vr9Ua7hEsydOhQ3XrrrdEuA7imnJ9guBQR38UzduxYNTY2as+ePVqyZInKysp05MiRSDcTkaqqKgUCAWc5ceLEFd0fAACIrohnUFJSUjR69GhJ0uTJk7Vv3z49++yzuvfee3Xu3Dm1tbWFzaK0tLQ4f1F5vV7t3bs3bHvn7/K52F9dqampSk3lsdYAAFwrLvs5KKFQSF1dXZo8ebKSk5NVW1vrtDU1Nam5uVk+n0+S5PP5dPDgQbW2tjp9ampq5Ha7lZ+ff7mlAACAOBHRDEpVVZVmz56tvLw8tbe3a+PGjdq5c6feeOMNeTweLVy4UJWVlcrIyJDb7dZDDz0kn8+nadOmSZJmzpyp/Px8zZ8/X2vWrJHf79fKlStVXl7ODAkAAHBEFFBaW1u1YMECnTp1Sh6PRwUFBXrjjTf03e9+V5L09NNPKyEhQaWlperq6lJxcbGef/555/OJiYnaunWrlixZIp/Pp8GDB6usrEyPP/54/x4VAACIaREFlJ///OcXbR8wYIDWrl2rtWvXfmmfkSNH6vXXX49ktwAA4BrDb/EAAADrEFAAAIB1CCgAAMA6BBQAAGCdPv8WDwBcyIQJEzR37txol/GVpkyZEu0SAFzEZf+acTSc/zXjS/k1RAAAYIdIvr85xQMAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFgnooCybt06FRQUyO12y+12y+fzadu2bU779OnT5XK5wpYHH3wwbBvNzc0qKSnRoEGDlJmZqeXLl6unp6d/jgYAAMSFpEg6jxgxQk8++aTGjBkjY4xefPFF3XXXXdq/f79uvvlmSdKiRYv0+OOPO58ZNGiQ87q3t1clJSXyer165513dOrUKS1YsEDJycl64okn+umQAABArHMZY8zlbCAjI0NPPfWUFi5cqOnTp+uWW27RM888c8G+27Zt0x133KGTJ08qKytLkrR+/XqtWLFCp0+fVkpKyiXtMxgMyuPxKBAIyO12X075AADgKonk+7vP16D09vZq06ZN6uzslM/nc9a/9NJLGjZsmCZMmKCqqip98sknTlt9fb0mTpzohBNJKi4uVjAY1OHDh790X11dXQoGg2ELAACIXxGd4pGkgwcPyufz6ezZs0pLS9OWLVuUn58vSfre976nkSNHKicnRwcOHNCKFSvU1NSkl19+WZLk9/vDwokk573f7//S
fVZXV2v16tWRlgoAAGJUxAFl7NixamxsVCAQ0K9//WuVlZWprq5O+fn5Wrx4sdNv4sSJys7O1owZM3Ts2DHdeOONfS6yqqpKlZWVzvtgMKjc3Nw+bw8AANgt4lM8KSkpGj16tCZPnqzq6mpNmjRJzz777AX7FhYWSpKOHj0qSfJ6vWppaQnrc/691+v90n2mpqY6dw6dXwAAQPy67OeghEIhdXV1XbCtsbFRkpSdnS1J8vl8OnjwoFpbW50+NTU1crvdzmkiAACAiE7xVFVVafbs2crLy1N7e7s2btyonTt36o033tCxY8e0ceNGzZkzR0OHDtWBAwe0bNky3X777SooKJAkzZw5U/n5+Zo/f77WrFkjv9+vlStXqry8XKmpqVfkAAEAQOyJKKC0trZqwYIFOnXqlDwejwoKCvTGG2/ou9/9rk6cOKE333xTzzzzjDo7O5Wbm6vS0lKtXLnS+XxiYqK2bt2qJUuWyOfzafDgwSorKwt7bgoAAMBlPwclGngOCgAAseeqPAcFAADgSiGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWIaAAAADrEFAAAIB1CCgAAMA6BBQAAGAdAgoAALAOAQUAAFiHgAIAAKxDQAEAANYhoAAAAOsQUAAAgHUIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwDgEFAABYh4ACAACsQ0ABAADWSYp2AX1hjJEkBYPBKFcCAAAu1fnv7fPf4xcTkwGlvb1dkpSbmxvlSgAAQKTa29vl8Xgu2sdlLiXGWCYUCqmpqUn5+fk6ceKE3G53tEuKWcFgULm5uYxjP2As+w9j2T8Yx/7DWPYPY4za29uVk5OjhISLX2USkzMoCQkJ+trXviZJcrvd/GPpB4xj/2Es+w9j2T8Yx/7DWF6+r5o5OY+LZAEAgHUIKAAAwDoxG1BSU1O1atUqpaamRruUmMY49h/Gsv8wlv2Dcew/jOXVF5MXyQIAgPgWszMoAAAgfhFQAACAdQgoAADAOgQUAABgnZgMKGvXrtX111+vAQMGqLCwUHv37o12SdbZtWuX7rzzTuXk5MjlcumVV14JazfG6NFHH1V2drYGDhyooqIiffDBB2F9zpw5o3nz5sntdis9PV0LFy5UR0fHVTyK6KuurtbUqVM1ZMgQZWZmau7cuWpqagrrc/bsWZWXl2vo0KFKS0tTaWmpWlpawvo0NzerpKREgwYNUmZmppYvX66enp6reShRtW7dOhUUFDgPufL5fNq2bZvTzhj23ZNPPimXy6WlS5c66xjPS/PYY4/J5XKFLePGjXPaGccoMzFm06ZNJiUlxfznf/6nOXz4sFm0aJFJT083LS0t0S7NKq+//rr5p3/6J/Pyyy8bSWbLli1h7U8++aTxeDzmlVdeMX/84x/N3/zN35hRo0aZTz/91Okza9YsM2nSJLN7927zu9/9zowePdrcf//9V/lIoqu4uNi88MIL5tChQ6axsdHMmTPH5OXlmY6ODqfPgw8+aHJzc01tba159913zbRp08xf//VfO+09PT1mwoQJpqioyOzfv9+8/vrrZtiwYaaqqioahxQV//3f/21+85vfmD/96U+mqanJ/PjHPzbJycnm0KFDxhjGsK/27t1rrr/+elNQUGB+8IMfOOsZz0uzatUqc/PNN5tTp045y+nTp512xjG6Yi6g3Hrrraa8vNx539vba3Jyckx1dXUUq7Lb5wNKKBQyXq/XPPXUU866trY2k5qaan7xi18YY4w5cuSIkWT27dvn9Nm2bZtxuVzmL3/5y1Wr3Tatra1GkqmrqzPGfDZuycnJZvPmzU6f9957z0gy9fX1xpjPwmJCQoLx+/1On3Xr1hm32226urqu7gFY5LrrrjP/8R//wRj2UXt7uxkzZoypqakx3/rWt5yAwnheulWr
VplJkyZdsI1xjL6YOsVz7tw5NTQ0qKioyFmXkJCgoqIi1dfXR7Gy2HL8+HH5/f6wcfR4PCosLHTGsb6+Xunp6ZoyZYrTp6ioSAkJCdqzZ89Vr9kWgUBAkpSRkSFJamhoUHd3d9hYjhs3Tnl5eWFjOXHiRGVlZTl9iouLFQwGdfjw4atYvR16e3u1adMmdXZ2yufzMYZ9VF5erpKSkrBxk/g3GakPPvhAOTk5uuGGGzRv3jw1NzdLYhxtEFM/FvjRRx+pt7c37B+DJGVlZen999+PUlWxx+/3S9IFx/F8m9/vV2ZmZlh7UlKSMjIynD7XmlAopKVLl+q2227ThAkTJH02TikpKUpPTw/r+/mxvNBYn2+7Vhw8eFA+n09nz55VWlqatmzZovz8fDU2NjKGEdq0aZP+8Ic/aN++fV9o49/kpSssLNSGDRs0duxYnTp1SqtXr9Y3v/lNHTp0iHG0QEwFFCCaysvLdejQIb399tvRLiUmjR07Vo2NjQoEAvr1r3+tsrIy1dXVRbusmHPixAn94Ac/UE1NjQYMGBDtcmLa7NmzndcFBQUqLCzUyJEj9atf/UoDBw6MYmWQYuwunmHDhikxMfELV1G3tLTI6/VGqarYc36sLjaOXq9Xra2tYe09PT06c+bMNTnWFRUV2rp1q9566y2NGDHCWe/1enXu3Dm1tbWF9f/8WF5orM+3XStSUlI0evRoTZ48WdXV1Zo0aZKeffZZxjBCDQ0Nam1t1V/91V8pKSlJSUlJqqur03PPPaekpCRlZWUxnn2Unp6um266SUePHuXfpQViKqCkpKRo8uTJqq2tddaFQiHV1tbK5/NFsbLYMmrUKHm93rBxDAaD2rNnjzOOPp9PbW1tamhocPrs2LFDoVBIhYWFV73maDHGqKKiQlu2bNGOHTs0atSosPbJkycrOTk5bCybmprU3NwcNpYHDx4MC3w1NTVyu93Kz8+/OgdioVAopK6uLsYwQjNmzNDBgwfV2NjoLFOmTNG8efOc14xn33R0dOjYsWPKzs7m36UNon2VbqQ2bdpkUlNTzYYNG8yRI0fM4sWLTXp6ethV1PjsCv/9+/eb/fv3G0nm3/7t38z+/fvNn//8Z2PMZ7cZp6enm1dffdUcOHDA3HXXXRe8zfjrX/+62bNnj3n77bfNmDFjrrnbjJcsWWI8Ho/ZuXNn2K2In3zyidPnwQcfNHl5eWbHjh3m3XffNT6fz/h8Pqf9/K2IM2fONI2NjWb79u1m+PDh19StiI888oipq6szx48fNwcOHDCPPPKIcblc5re//a0xhjG8XP/7Lh5jGM9L9fDDD5udO3ea48ePm9///vemqKjIDBs2zLS2thpjGMdoi7mAYowxP/vZz0xeXp5JSUkxt956q9m9e3e0S7LOW2+9ZSR9YSkrKzPGfHar8U9+8hOTlZVlUlNTzYwZM0xTU1PYNj7++GNz//33m7S0NON2u833v/99097eHoWjiZ4LjaEk88ILLzh9Pv30U/OP//iP5rrrrjODBg0yf/u3f2tOnToVtp3/+Z//MbNnzzYDBw40w4YNMw8//LDp7u6+ykcTPf/wD/9gRo4caVJSUszw4cPNjBkznHBiDGN4uT4fUBjPS3Pvvfea7Oxsk5KSYr72ta+Ze++91xw9etRpZxyjy2WMMdGZuwEAALiwmLoGBQAAXBsIKAAAwDoEFAAAYB0CCgAAsA4BBQAAWIeAAgAArENAAQAA1iGgAAAA6xBQAACAdQgoAADAOgQUAABgHQIKAACwzv8DOZx9PeTUlsMAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "200.0"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(play=True)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第7章-DQN算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
